/**
 * 
 */
package edu.umd.clip.lm.tools;

import java.io.*;
import java.nio.channels.*;
import java.util.*;

import edu.berkeley.nlp.util.*;
import edu.umd.clip.lm.factors.*;
import edu.umd.clip.lm.factors.Dictionary;
import edu.umd.clip.lm.model.*;
import edu.umd.clip.lm.model.data.*;
import edu.umd.clip.lm.ngram.Ngram;

/**
 * Command-line tool that counts the n-grams found in on-disk training data
 * files and prints them to standard output, one n-gram per line in the form
 * {@code <context tokens separated by spaces> <future word>\t<count>}.
 * <p>
 * Counts for the full-length n-grams are accumulated directly from the data;
 * counts for shorter n-grams are derived from them by repeatedly applying
 * {@link Ngram#getLowerOrderNgram()} and summing.
 *
 * @author Denis Filimonov <den@cs.umd.edu>
 *
 */
public class TrainingDataNgramCounts {
	/** Command-line options understood by {@link #main(String[])}. */
	public static class Options {
		/** Comma-separated list of training data files to read. */
        @Option(name = "-data-files", required = true, usage = "Training data files (comma-separated)")
		public String dataFiles;
		/** Path of the XML experiment configuration passed to {@code Experiment.initialize}. */
        @Option(name = "-config", required = true, usage = "XML config file")
		public String config;
		/**
		 * Requested n-gram order.
		 * NOTE(review): this value is parsed but never read by this tool; the
		 * order actually used is taken from the n-grams built from the data.
		 */
        @Option(name = "-order", required = false, usage = "order of n-grams (default 3)")
		public int order = 3;
		/** When set, factor tuples are not masked down to the overt factors. */
        @Option(name = "-tags", required = false, usage = "don't mask tags out")
		public boolean tags;
	}

	/**
	 * Entry point: parses the options, initializes the experiment, counts the
	 * n-grams in all data files and prints the counts for every order.
	 * Prints nothing if the data files contain no n-grams.
	 *
	 * @param args command-line arguments, see {@link Options}
	 * @throws IOException if a data file cannot be opened or read
	 */
	public static void main(String[] args) throws IOException {
        OptionParser optParser = new OptionParser(Options.class);
        final Options opts = (Options) optParser.parse(args, true);

		Experiment.initialize(opts.config);
		Experiment experiment = Experiment.getInstance();
		experiment.buildPrefixes();

		// With -tags every bit of a tuple is kept; otherwise only the overt factors survive.
		final long overtMask = opts.tags ? -1L : experiment.getTupleDescription().getOvertFactorsMask();

		HashMap<Ngram,MutableInteger> ngramCounts = countFullNgrams(opts.dataFiles.split(","), overtMask);
		if (ngramCounts.isEmpty()) return;

		HashMap<Ngram,MutableInteger>[] countsByOrder = deriveLowerOrderCounts(ngramCounts);

		byte wordIdx = experiment.getTupleDescription().getMainFactorIndex();
		Dictionary dict = experiment.getTupleDescription().getDictionary(wordIdx);

		printCounts(countsByOrder, dict, wordIdx, opts.tags);
	}

	/**
	 * Reads every training data file and accumulates the count of each
	 * (context, future) n-gram after applying {@code overtMask} to all tuples.
	 *
	 * @param fileNames paths of the training data files
	 * @param overtMask bit mask AND-ed onto every context and future tuple
	 * @return map from n-gram to its accumulated count
	 * @throws IOException if a file cannot be opened, read or closed
	 */
	private static HashMap<Ngram,MutableInteger> countFullNgrams(String[] fileNames, long overtMask) throws IOException {
		HashMap<Ngram,MutableInteger> ngramCounts = new HashMap<Ngram,MutableInteger>(10000000);

		for(String fileName : fileNames) {
			FileInputStream stream = new FileInputStream(fileName);
			try {
				ReadableByteChannel channel = stream.getChannel();
				TrainingDataReader reader = new OnDiskTrainingDataReader(channel);
				ReadableTrainingData inputData = new ReadableTrainingData(reader);

				while(inputData.hasNext()) {
					TrainingDataBlock block = inputData.next();

					for(ContextFuturesPair pair : block) {
						// The context array is masked in place and then shared by
						// all futures of this pair.
						long [] context = pair.getContext().data;
						for(int i=0; i<context.length; ++i) {
							context[i] &= overtMask;
						}

						for(TupleCountPair tc : pair.getFutures()) {
							Ngram ngram = Ngram.makeNgram(context, tc.tuple & overtMask);

							MutableInteger count = ngramCounts.get(ngram);
							if (count == null) {
								ngramCounts.put(ngram, new MutableInteger(tc.count));
							} else {
								count.add(tc.count);
							}
						}
					}
				}
			} finally {
				// Closing the stream also closes its channel; doing it here
				// releases the file even when reading fails part-way through.
				stream.close();
			}
		}
		return ngramCounts;
	}

	/**
	 * Builds the per-order count tables. The last slot holds the given
	 * full-length counts; each preceding slot is obtained by mapping every
	 * n-gram of the following slot through {@link Ngram#getLowerOrderNgram()}
	 * and summing the counts that collapse onto the same shorter n-gram.
	 *
	 * @param ngramCounts non-empty counts of the full-length n-grams
	 * @return array whose length is the value of {@code getOrder()} of one of
	 *         the given n-grams, with the given map in the last slot
	 */
	@SuppressWarnings("unchecked") // generic array creation; only HashMap<Ngram,MutableInteger> is ever stored
	private static HashMap<Ngram,MutableInteger>[] deriveLowerOrderCounts(HashMap<Ngram,MutableInteger> ngramCounts) {
		// The order is taken from an arbitrary n-gram of the map.
		int order = ngramCounts.keySet().iterator().next().getOrder();

		HashMap<Ngram,MutableInteger> countsByOrder[] = new HashMap[order];
		countsByOrder[countsByOrder.length-1] = ngramCounts;

		for(int i=countsByOrder.length-2; i>=0; --i) {
			HashMap<Ngram,MutableInteger> prevCounts = countsByOrder[i+1];

			// Initial capacity is a heuristic guess at the number of distinct shorter n-grams.
			HashMap<Ngram,MutableInteger> lowerCounts =
				new HashMap<Ngram,MutableInteger>((int) (Math.pow(prevCounts.size(), 0.7) + 1));
			countsByOrder[i] = lowerCounts;

			for(Map.Entry<Ngram, MutableInteger> e : prevCounts.entrySet()) {
				Ngram ngram = e.getKey().getLowerOrderNgram();
				int count = e.getValue().intValue();

				MutableInteger newCount = lowerCounts.get(ngram);
				if (newCount == null) {
					newCount = new MutableInteger();
					lowerCounts.put(ngram, newCount);
				}
				newCount.add(count);
			}
		}
		return countsByOrder;
	}

	/**
	 * Prints every n-gram of every table to standard output, tables in array
	 * order, as {@code <context tokens> <future word>\t<count>}.
	 *
	 * @param countsByOrder count tables to print
	 * @param dict dictionary used to turn main-factor values into words
	 * @param wordIdx index of the main factor inside a tuple
	 * @param tags if true, context tokens are printed via
	 *             {@code FactorTuple.toStringNoNull}; otherwise as dictionary words
	 */
	private static void printCounts(HashMap<Ngram,MutableInteger>[] countsByOrder, Dictionary dict, byte wordIdx, boolean tags) {
		// One builder reused for all lines so each n-gram costs a single println.
		StringBuilder line = new StringBuilder();

		for(HashMap<Ngram,MutableInteger> counts : countsByOrder) {
			for(Map.Entry<Ngram, MutableInteger> entry : counts.entrySet()) {
				Ngram ngram = entry.getKey();
				int count = entry.getValue().intValue();

				line.setLength(0);
				for(long word : ngram.getContext()) {
					if (tags) {
						line.append(FactorTuple.toStringNoNull(word));
					} else {
						line.append(dict.getWord(FactorTuple.getValue(word, wordIdx)));
					}
					line.append(' ');
				}
				// NOTE(review): the future token is printed as a bare word even
				// when tags are enabled, unlike the context tokens — confirm
				// this asymmetry is intended.
				line.append(dict.getWord(FactorTuple.getValue(ngram.getFuture(), wordIdx)));
				line.append('\t');
				line.append(count);
				System.out.println(line);
			}
		}
	}
}
